{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The visualization used for this seminar is based on Alexandr Verinov's code.  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generative models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this seminar we will try several criterions for learning an implicit model. For the first part almost everything is written for you, and you only need to implement the objective for the game and play around with the model. \n",
    "\n",
    "**0)** Read the code\n",
    "\n",
    "**1)** Implement objective for a vanilla [Generative Adversarial Networks](https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf) (GAN). The hyperparameters are already set in the code. The model will converge if you implement the objective (1) right. \n",
    "\n",
    "**2)** Note the discussion in the paper, that the objective for $G$ can be of two kinds: $\\min_G \\log(1 - D)$ and $\\min_G - \\log(D)$. Now implement the second objective and ensure model converges. Most likely, in this example you will not notice the difference, but people usually use the second objective, it really matters in more complicated scenarios. **NOTE:** the objective for D stays the same.\n",
    "\n",
    "**3 & 4)** Implement [Wasserstein GAN](https://arxiv.org/pdf/1701.07875.pdf) and WGAN-GP. To make the discriminator has Lipschitz property you need to clip discriminator's weights to $[-0.01, 0.01]$ range (WGAN) or use gradient penalty (WGAN-GP). You will need to make few modifications to the code: \n",
    "\n",
    "   - Remove sigmoids from discriminator;\n",
    "   - Change objective (see eq. 3 and algorithm 1 in [the paper](https://arxiv.org/pdf/1701.07875.pdf)): \n",
    "   - Add weight clipping for D [see here](https://github.com/martinarjovsky/WassersteinGAN/blob/master/main.py#L172) / gradient penaly (WGAN-GP) [code](https://gist.github.com/DmitryUlyanov/19ce84045135e3f81a477629e685aec8); \n",
    "\n",
    "   \n",
    "In general see [implementation 1](https://github.com/martinarjovsky/WassersteinGAN/blob/master/main.py#L172) / [implementation 2](https://github.com/caogang/wgan-gp). They also use different optimizer. \n",
    "\n",
    "The default hyperparameters may not work well, spend some time to tune them -- play with learning rate, number of D updates per one G update, change architecture (what about weight initialization?). \n",
    "\n",
    "**5) Bonus: Wasserstein Introspective Neural Networkss**. This is basically WGAN-GP without generator. Read and implement [WINN paper](https://arxiv.org/pdf/1711.08875.pdf) for our toy task. The classification step is almost identical to the discriminative step for WGAN-GP. However on synthesis step, we will not use a generator network, but instead we optimize the same loss as the generator loss in WGAN-GP with respect to the *generated objects* (aka \"pseudo-negative samples\"). Then, we accumulate the generated \"pseudo-negative\" samples and use mini-batches from them as the \"fake data\" for the next classification step.\n",
    "\n",
    "Here are some tips for you:\n",
    "- Initialize your \"fake dataset\" with random noise.\n",
    "- During the classification stage, sample fake data from the fake dataset.\n",
    "- For the synthesis step, use the fake samples from the previous step as the initial value.\n",
    "- You can use an ordinary Adam optimizer to update the samples, but you need to inject small noise on each step (last equation on page 4). Do not forget to early stop after the threshold (page 5, first paragraph) is reached.\n",
    "- Add the new generated points to the \"fake dataset\".\n",
    "\n",
    "To make the visualization work without a generator, you have to supply your generated samples to `vis_points` function."
   ]
  },
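  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the objectives from the papers linked above can be written as losses to minimize, which is a sketch of what `g_loss` and `d_loss` are expected to compute. The notation is an assumption made here and is not used elsewhere in the notebook: $x$ is a real sample, $z$ is noise, and each expectation is estimated by a mini-batch mean.\n",
    "\n",
    "Vanilla GAN (tasks 1 and 2), where $D$ outputs a probability:\n",
    "\n",
    "$$L_D = -\\mathbb{E}_{x}\\left[\\log D(x)\\right] - \\mathbb{E}_{z}\\left[\\log\\left(1 - D(G(z))\\right)\\right]$$\n",
    "\n",
    "$$L_G^{(1)} = \\mathbb{E}_{z}\\left[\\log\\left(1 - D(G(z))\\right)\\right], \\qquad L_G^{(2)} = -\\mathbb{E}_{z}\\left[\\log D(G(z))\\right]$$\n",
    "\n",
    "WGAN (tasks 3 and 4), where $D$ outputs an unbounded score:\n",
    "\n",
    "$$L_D = \\mathbb{E}_{z}\\left[D(G(z))\\right] - \\mathbb{E}_{x}\\left[D(x)\\right], \\qquad L_G = -\\mathbb{E}_{z}\\left[D(G(z))\\right]$$"
   ]
  },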
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\" \n",
    "    Please, implement everything in one notebook, using if statements to switch between the tasks\n",
    "\"\"\"\n",
    "# TASK in [1, 2, 3, 4, 5]\n",
    "TASK = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "torch.set_num_threads(4)\n",
    "np.random.seed(12345)\n",
    "lims=(-5, 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define sampler from real data and Z "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Some utility functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import rv_discrete\n",
    "\n",
    "MEANS = np.array(\n",
    "        [[-1,-3],\n",
    "         [1,3],\n",
    "         [-2,0],\n",
    "        ])\n",
    "COVS = np.array(\n",
    "        [[[1,0.8],[0.8,1]],\n",
    "        [[1,-0.5],[-0.5,1]],\n",
    "        [[1,0],[0,1]],\n",
    "        ])\n",
    "PROBS = np.array([\n",
    "        0.2,\n",
    "        0.5,\n",
    "        0.3\n",
    "        ])\n",
    "assert len(MEANS) == len(COVS) == len(PROBS), \"number of components mismatch\"\n",
    "COMPONENTS = len(MEANS)\n",
    "\n",
    "comps_dist = rv_discrete(values=(range(COMPONENTS), PROBS))\n",
    "\n",
    "def sample_true(N):\n",
    "    comps = comps_dist.rvs(size=N)\n",
    "    conds = np.arange(COMPONENTS)[:,None] == comps[None,:]\n",
    "    arr = np.array([np.random.multivariate_normal(MEANS[c], COVS[c], size=N)\n",
    "                     for c in range(COMPONENTS)])\n",
    "    return np.select(conds[:,:,None], arr).astype(np.float32)\n",
    "\n",
    "NOISE_DIM = 20\n",
    "def sample_noise(N):\n",
    "    return np.random.normal(size=(N,NOISE_DIM)).astype(np.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualization functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And more utility functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vis_data(data):\n",
    "    \"\"\"\n",
    "        Visualizes data as histogram\n",
    "    \"\"\"\n",
    "    hist = np.histogram2d(data[:, 1], data[:, 0], bins=100, range=[lims, lims])\n",
    "    plt.pcolormesh(hist[1], hist[2], hist[0], alpha=0.5)\n",
    "\n",
    "fixed_noise = torch.Tensor(sample_noise(1000))\n",
    "def vis_g():\n",
    "    \"\"\"\n",
    "        Visualizes generator's samples as circles\n",
    "    \"\"\"\n",
    "    data = generator(fixed_noise).data.numpy()\n",
    "    if np.isnan(data).any():\n",
    "        return\n",
    "    \n",
    "    plt.scatter(data[:,0], data[:,1], alpha=0.2, c='b')\n",
    "    plt.xlim(lims)\n",
    "    plt.ylim(lims)\n",
    "    \n",
    "    \n",
    "def vis_points(data):\n",
    "    \"\"\"\n",
    "        Visualizes the supplied samples as circles\n",
    "    \"\"\"\n",
    "    if np.isnan(data).any():\n",
    "        return\n",
    "    \n",
    "    plt.scatter(data[:,0], data[:,1], alpha=0.2, c='b')\n",
    "    plt.xlim(lims)\n",
    "    plt.ylim(lims)\n",
    "    \n",
    "\n",
    "def get_grid():\n",
    "    X, Y = np.meshgrid(np.linspace(lims[0], lims[1], 30), np.linspace(lims[0], lims[1], 30))\n",
    "    X = X.flatten()\n",
    "    Y = Y.flatten()\n",
    "        \n",
    "    grid = torch.from_numpy(np.vstack([X, Y]).astype(np.float32).T)\n",
    "    grid.requires_grad = True\n",
    "                            \n",
    "    return X, Y, grid\n",
    "              \n",
    "X_grid, Y_grid, grid = get_grid()\n",
    "def vis_d():\n",
    "    \"\"\"\n",
    "        Visualizes discriminator's gradient on grid\n",
    "    \"\"\"\n",
    "        \n",
    "    data_gen = generator(fixed_noise)\n",
    "#     loss = d_loss(discriminator(data_gen), discriminator(grid))\n",
    "    loss = g_loss(discriminator(grid))\n",
    "    loss.backward()\n",
    "    \n",
    "    grads = - grid.grad.data.numpy()\n",
    "    grid.grad.data *= 0 \n",
    "    plt.quiver(X_grid, Y_grid, grads[:, 0], grads[:, 1], color='black',alpha=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define architectures"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After you've done with task 1 you can play with architectures."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_generator(noise_dim, out_dim, hidden_dim=100):\n",
    "    layers = [\n",
    "        nn.Linear(noise_dim, hidden_dim),\n",
    "        nn.LeakyReLU(),\n",
    "        nn.Linear(hidden_dim, hidden_dim),\n",
    "        nn.LeakyReLU(),\n",
    "        nn.Linear(hidden_dim, out_dim)\n",
    "    ]\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "def get_discriminator(in_dim, hidden_dim=100):\n",
    "    layers = [\n",
    "        nn.Linear(in_dim, hidden_dim),\n",
    "        nn.LeakyReLU(),\n",
    "        nn.Linear(hidden_dim, hidden_dim),\n",
    "        nn.LeakyReLU(),\n",
    "        nn.Linear(hidden_dim, hidden_dim),\n",
    "        nn.LeakyReLU(),\n",
    "        nn.Linear(hidden_dim, 1),\n",
    "        nn.Sigmoid()\n",
    "    ]\n",
    "        \n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define updates and losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = get_generator(NOISE_DIM, out_dim = 2)\n",
    "discriminator = get_discriminator(in_dim = 2)\n",
    "\n",
    "lr = 0.001\n",
    "g_optimizer = optim.Adam(generator.parameters(),     lr=lr, betas=(0.5, 0.999))\n",
    "d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice we are using ADAM optimizer with `beta1=0.5` for both discriminator and discriminator. This is a common practice and works well. Motivation: models should be flexible and adapt itself rapidly to the distributions. \n",
    "\n",
    "You can try different optimizers and parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "################################\n",
    "#       IMPLEMENT HERE\n",
    "################################\n",
    "# Define the g_loss and d_loss here\n",
    "# these are the only lines of code you need to change to implement Tasks 1 and 2 \n",
    "\n",
    "def g_loss(d_scores_fake):\n",
    "    \"\"\"\n",
    "        `d_scores_fake` is the output of the discrimonator model applied to a batch of fake data\n",
    "        \n",
    "        NOTE: we always define objectives as if we were minimizing them (remember that maximize = negate and minimize)\n",
    "    \"\"\"\n",
    "    # if TASK == 1: \n",
    "    #     return something\n",
    "    #  elif TASK == 2:\n",
    "    #     return something else\n",
    "    \n",
    "    return # TODO\n",
    "    \n",
    "def d_loss(d_scores_fake, d_scores_real):\n",
    "    \"\"\"\n",
    "        `d_scores_fake` is the output of the discrimonator model applied to a batch of fake data\n",
    "        `d_scores_real` is the output of the discrimonator model applied to a batch of real data\n",
    "        \n",
    "        NOTE: we always define objectives as if we were minimizing them (remember that maximize = negate and minimize)\n",
    "    \"\"\"\n",
    "    # if TASK == 1: \n",
    "    #     return something\n",
    "    #  elif TASK == 2:\n",
    "    #     return something else\n",
    "    \n",
    "    return # TODO"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get real data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = sample_true(100000)\n",
    "def iterate_minibatches(X, batchsize, y=None):\n",
    "    perm = np.random.permutation(X.shape[0])\n",
    "    \n",
    "    for start in range(0, X.shape[0], batchsize):\n",
    "        end = min(start + batchsize, X.shape[0])\n",
    "        if y is None:\n",
    "            yield X[perm[start:end]]\n",
    "        else:\n",
    "            yield X[perm[start:end]], y[perm[start:end]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "plt.rcParams['figure.figsize'] = (12, 12)\n",
    "vis_data(data)\n",
    "vis_g()\n",
    "vis_d()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Legend**:\n",
    "- Blue dots are generated samples. \n",
    "- Colored histogram at the back shows density of real data. \n",
    "- And with arrows we show gradients of the discriminator -- they are the directions that discriminator pushes generator's samples. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the model"
   ]
  },
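  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For task 4, the gradient penalty from the WGAN-GP paper is added to the discriminator loss. As a sketch in notation assumed here (not defined elsewhere in the notebook), with $x$ a real sample, $\\tilde{x} = G(z)$ a fake sample, and $\\epsilon \\sim U[0, 1]$ drawn independently for each pair:\n",
    "\n",
    "$$\\hat{x} = \\epsilon x + (1 - \\epsilon) \\tilde{x}, \\qquad L_{GP} = \\lambda \\, \\mathbb{E}_{\\hat{x}}\\left[\\left(\\left\\lVert \\nabla_{\\hat{x}} D(\\hat{x}) \\right\\rVert_2 - 1\\right)^2\\right]$$\n",
    "\n",
    "The paper uses $\\lambda = 10$; treat that as a starting point to tune for this toy problem rather than a recommended setting."
   ]
  },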
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from IPython import display\n",
    "\n",
    "plt.xlim(lims)\n",
    "plt.ylim(lims)\n",
    "\n",
    "batch_size = 64\n",
    "\n",
    "# ===========================\n",
    "# IMPORTANT PARAMETER:\n",
    "# Number of D updates per G update\n",
    "# ===========================\n",
    "k_d, k_g = 1, 1\n",
    "\n",
    "try:\n",
    "    for it, real_data in enumerate(iterate_minibatches(data, batch_size)):\n",
    "\n",
    "        # Optimize D\n",
    "        for _ in range(k_d):\n",
    "            d_optimizer.zero_grad()\n",
    "            \n",
    "            # Sample noise\n",
    "            noise = torch.Tensor(sample_noise(real_data.shape[0]))\n",
    "\n",
    "            # Compute gradient\n",
    "            real_data = torch.Tensor(real_data)\n",
    "            fake_data = generator(noise)\n",
    "            loss = d_loss(discriminator(fake_data), discriminator(real_data))            \n",
    "            loss.backward()\n",
    "            \n",
    "            # IMPLEMENT HERE GP FOR TASK 4\n",
    "                \n",
    "            # Update\n",
    "            d_optimizer.step()\n",
    "\n",
    "        # Optimize G\n",
    "        for _ in range(k_g):\n",
    "            g_optimizer.zero_grad()\n",
    "            \n",
    "            # Sample noise\n",
    "            noise = torch.Tensor(sample_noise(real_data.shape[0]))\n",
    "\n",
    "            # Compute gradient\n",
    "            fake_data = generator(noise)\n",
    "            loss = g_loss(discriminator(fake_data))\n",
    "            loss.backward()\n",
    "            \n",
    "            # Update\n",
    "            g_optimizer.step()\n",
    "\n",
    "        # Visualize\n",
    "        if it % 2 == 0:\n",
    "            plt.clf()\n",
    "            vis_data(data)\n",
    "            \n",
    "            if TASK < 5:\n",
    "                vis_g()\n",
    "            else:\n",
    "                # UNCOMMENT AND SUPPLY YOUR SAMPLES FOR BONUS TASK 5\n",
    "                # vis_points(generated_samples[-1000:])\n",
    "            vis_d()\n",
    "            display.clear_output(wait=True)\n",
    "            display.display(plt.gcf())\n",
    "            print(f\"Task {TASK}; Iteration {it}\")\n",
    "        \n",
    "except KeyboardInterrupt:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Describe your findings here"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "London is the capital of Great Britain."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
